-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir] Enable disabling folding in dialect conversion #152890
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Previously this only happened post checking if the op isn't legal, but was done unconditionally post (and before other legalization patterns). Add option to not attempt folding and one to do so as last resort.
@llvm/pr-subscribers-mlir Author: Jacques Pienaar (jpienaar) ChangesPreviously this only happened post checking if the op is legal, but was done unconditionally post (and before other legalization patterns). Add option to not attempt folding and one to do so as last resort. Did consider but did not add a always attempt to fold option (which would have folded whether or not legal), but removed TODO about it. Full diff: https://github.com/llvm/llvm-project/pull/152890.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f6437657c9a93..da092fb24c4ff 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1158,6 +1158,16 @@ class PDLConversionConfig final {
// ConversionConfig
//===----------------------------------------------------------------------===//
+/// An enum to control folding behavior during dialect conversion.
+enum class DialectConversionFoldingMode {
+ /// Never attempt to fold.
+ Never,
+ /// Only attempt to fold not legal operations before applying patterns.
+ BeforePatterns,
+ /// Only attempt to fold not legal operations after applying patterns.
+ AfterPatterns,
+};
+
/// Dialect conversion configuration.
struct ConversionConfig {
/// An optional callback used to notify about match failure diagnostics during
@@ -1240,6 +1250,10 @@ struct ConversionConfig {
/// your patterns do not trigger any IR rollbacks. For details, see
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
bool allowPatternRollback = true;
+
+ /// The folding mode to use during conversion.
+ DialectConversionFoldingMode foldingMode =
+ DialectConversionFoldingMode::BeforePatterns;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 08803e082b057..4aa934be2abc9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2197,15 +2197,16 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
- // If the operation isn't legal, try to fold it in-place.
- // TODO: Should we always try to do this, even if the op is
- // already legal?
- if (succeeded(legalizeWithFold(op, rewriter))) {
- LLVM_DEBUG({
- logSuccess(logger, "operation was folded");
- logger.startLine() << logLineComment;
- });
- return success();
+ // If the operation is not legal, try to fold it in-place if the folding mode
+ // is 'BeforePatterns'. 'Never' will skip this.
+ if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
+ if (succeeded(legalizeWithFold(op, rewriter))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
}
// Otherwise, we need to apply a legalization pattern to this operation.
@@ -2217,6 +2218,18 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
+ // If the operation can't be legalized via patterns, try to fold it in-place
+ // if the folding mode is 'AfterPatterns'.
+ if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
+ if (succeeded(legalizeWithFold(op, rewriter))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
+ }
+
LLVM_DEBUG({
logFailure(logger, "no matched legalization pattern");
logger.startLine() << logLineComment;
diff --git a/mlir/test/Transforms/test-legalizer-fold-after.mlir b/mlir/test/Transforms/test-legalizer-fold-after.mlir
new file mode 100644
index 0000000000000..7f80252dc9604
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-fold-after.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=after-patterns" | FileCheck %s
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+ // CHECK-NOT: op_in_place_self_fold
+ // CHECK: 97
+ %1 = "test.op_in_place_self_fold"() : () -> (i32)
+ "test.return"(%1) : (i32) -> ()
+}
diff --git a/mlir/test/Transforms/test-legalizer-fold-before.mlir b/mlir/test/Transforms/test-legalizer-fold-before.mlir
new file mode 100644
index 0000000000000..fe6e29351a5d7
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-fold-before.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=before-patterns" | FileCheck %s
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+ // CHECK: op_in_place_self_fold
+ // CHECK-SAME: folded
+ %1 = "test.op_in_place_self_fold"() : () -> (i32)
+ "test.return"(%1) : (i32) -> ()
+}
diff --git a/mlir/test/Transforms/test-legalizer-no-fold.mlir b/mlir/test/Transforms/test-legalizer-no-fold.mlir
new file mode 100644
index 0000000000000..c2d4dff2b4d3d
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-no-fold.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-legalize-patterns="test-legalize-folding-mode=never" | FileCheck %s
+
+// CHECK-LABEL: @remove_foldable_op(
+func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
+ // Check that op was not folded.
+ // CHECK: "test.op_with_region_fold"
+ %0 = "test.op_with_region_fold"(%arg0) ({
+ "foo.op_with_region_terminator"() : () -> ()
+ }) : (i32) -> (i32)
+ "test.return"(%0) : (i32) -> ()
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2eaad552a7a3a..fee05618c1d58 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1478,6 +1478,8 @@ def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
let results = (outs I32);
let hasFolder = 1;
}
+def : Pat<(TestOpInPlaceSelfFold:$op $_),
+ (TestOpConstant ConstantAttr<I32Attr, "97">)>;
// Test op that simply returns success.
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index eda618f5b09c6..969fef928b88b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1402,7 +1402,7 @@ struct TestLegalizePatternDriver
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
- TerminatorOp, OneRegionOp>();
+ TerminatorOp, OneRegionOp, TestOpConstant>();
target.addLegalOp(OperationName("test.legal_op", &getContext()));
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
@@ -1457,6 +1457,7 @@ struct TestLegalizePatternDriver
DumpNotifications dumpNotifications;
config.listener = &dumpNotifications;
config.unlegalizedOps = &unlegalizedOps;
+ config.foldingMode = foldingMode;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config))) {
getOperation()->emitRemark() << "applyPartialConversion failed";
@@ -1476,6 +1477,7 @@ struct TestLegalizePatternDriver
ConversionConfig config;
DumpNotifications dumpNotifications;
+ config.foldingMode = foldingMode;
config.listener = &dumpNotifications;
if (failed(applyFullConversion(getOperation(), target,
std::move(patterns), config))) {
@@ -1490,6 +1492,7 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
ConversionConfig config;
+ config.foldingMode = foldingMode;
config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getOperation(), target,
std::move(patterns), config)))
@@ -1510,6 +1513,21 @@ struct TestLegalizePatternDriver
clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"),
clEnumValN(ConversionMode::Partial, "partial",
"Perform a partial conversion"))};
+
+ Option<DialectConversionFoldingMode> foldingMode{
+ *this, "test-legalize-folding-mode",
+ llvm::cl::desc("The folding mode to use with the test driver"),
+ llvm::cl::init(DialectConversionFoldingMode::BeforePatterns),
+ llvm::cl::values(clEnumValN(DialectConversionFoldingMode::Never, "never",
+ "Never attempt to fold"),
+ clEnumValN(DialectConversionFoldingMode::BeforePatterns,
+ "before-patterns",
+ "Only attempt to fold not legal operations "
+ "before applying patterns"),
+ clEnumValN(DialectConversionFoldingMode::AfterPatterns,
+ "after-patterns",
+ "Only attempt to fold not legal operations "
+ "after applying patterns"))};
};
} // namespace
|
@llvm/pr-subscribers-mlir-core Author: Jacques Pienaar (jpienaar) ChangesPreviously this only happened post checking if the op is legal, but was done unconditionally post (and before other legalization patterns). Add option to not attempt folding and one to do so as last resort. Did consider but did not add a always attempt to fold option (which would have folded whether or not legal), but removed TODO about it. Full diff: https://github.com/llvm/llvm-project/pull/152890.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f6437657c9a93..da092fb24c4ff 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1158,6 +1158,16 @@ class PDLConversionConfig final {
// ConversionConfig
//===----------------------------------------------------------------------===//
+/// An enum to control folding behavior during dialect conversion.
+enum class DialectConversionFoldingMode {
+ /// Never attempt to fold.
+ Never,
+ /// Only attempt to fold not legal operations before applying patterns.
+ BeforePatterns,
+ /// Only attempt to fold not legal operations after applying patterns.
+ AfterPatterns,
+};
+
/// Dialect conversion configuration.
struct ConversionConfig {
/// An optional callback used to notify about match failure diagnostics during
@@ -1240,6 +1250,10 @@ struct ConversionConfig {
/// your patterns do not trigger any IR rollbacks. For details, see
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
bool allowPatternRollback = true;
+
+ /// The folding mode to use during conversion.
+ DialectConversionFoldingMode foldingMode =
+ DialectConversionFoldingMode::BeforePatterns;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 08803e082b057..4aa934be2abc9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2197,15 +2197,16 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
- // If the operation isn't legal, try to fold it in-place.
- // TODO: Should we always try to do this, even if the op is
- // already legal?
- if (succeeded(legalizeWithFold(op, rewriter))) {
- LLVM_DEBUG({
- logSuccess(logger, "operation was folded");
- logger.startLine() << logLineComment;
- });
- return success();
+ // If the operation is not legal, try to fold it in-place if the folding mode
+ // is 'BeforePatterns'. 'Never' will skip this.
+ if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
+ if (succeeded(legalizeWithFold(op, rewriter))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
}
// Otherwise, we need to apply a legalization pattern to this operation.
@@ -2217,6 +2218,18 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
+ // If the operation can't be legalized via patterns, try to fold it in-place
+ // if the folding mode is 'AfterPatterns'.
+ if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
+ if (succeeded(legalizeWithFold(op, rewriter))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
+ }
+
LLVM_DEBUG({
logFailure(logger, "no matched legalization pattern");
logger.startLine() << logLineComment;
diff --git a/mlir/test/Transforms/test-legalizer-fold-after.mlir b/mlir/test/Transforms/test-legalizer-fold-after.mlir
new file mode 100644
index 0000000000000..7f80252dc9604
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-fold-after.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=after-patterns" | FileCheck %s
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+ // CHECK-NOT: op_in_place_self_fold
+ // CHECK: 97
+ %1 = "test.op_in_place_self_fold"() : () -> (i32)
+ "test.return"(%1) : (i32) -> ()
+}
diff --git a/mlir/test/Transforms/test-legalizer-fold-before.mlir b/mlir/test/Transforms/test-legalizer-fold-before.mlir
new file mode 100644
index 0000000000000..fe6e29351a5d7
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-fold-before.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=before-patterns" | FileCheck %s
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+ // CHECK: op_in_place_self_fold
+ // CHECK-SAME: folded
+ %1 = "test.op_in_place_self_fold"() : () -> (i32)
+ "test.return"(%1) : (i32) -> ()
+}
diff --git a/mlir/test/Transforms/test-legalizer-no-fold.mlir b/mlir/test/Transforms/test-legalizer-no-fold.mlir
new file mode 100644
index 0000000000000..c2d4dff2b4d3d
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-no-fold.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-legalize-patterns="test-legalize-folding-mode=never" | FileCheck %s
+
+// CHECK-LABEL: @remove_foldable_op(
+func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
+ // Check that op was not folded.
+ // CHECK: "test.op_with_region_fold"
+ %0 = "test.op_with_region_fold"(%arg0) ({
+ "foo.op_with_region_terminator"() : () -> ()
+ }) : (i32) -> (i32)
+ "test.return"(%0) : (i32) -> ()
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2eaad552a7a3a..fee05618c1d58 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1478,6 +1478,8 @@ def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
let results = (outs I32);
let hasFolder = 1;
}
+def : Pat<(TestOpInPlaceSelfFold:$op $_),
+ (TestOpConstant ConstantAttr<I32Attr, "97">)>;
// Test op that simply returns success.
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index eda618f5b09c6..969fef928b88b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1402,7 +1402,7 @@ struct TestLegalizePatternDriver
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
- TerminatorOp, OneRegionOp>();
+ TerminatorOp, OneRegionOp, TestOpConstant>();
target.addLegalOp(OperationName("test.legal_op", &getContext()));
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
@@ -1457,6 +1457,7 @@ struct TestLegalizePatternDriver
DumpNotifications dumpNotifications;
config.listener = &dumpNotifications;
config.unlegalizedOps = &unlegalizedOps;
+ config.foldingMode = foldingMode;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config))) {
getOperation()->emitRemark() << "applyPartialConversion failed";
@@ -1476,6 +1477,7 @@ struct TestLegalizePatternDriver
ConversionConfig config;
DumpNotifications dumpNotifications;
+ config.foldingMode = foldingMode;
config.listener = &dumpNotifications;
if (failed(applyFullConversion(getOperation(), target,
std::move(patterns), config))) {
@@ -1490,6 +1492,7 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
ConversionConfig config;
+ config.foldingMode = foldingMode;
config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getOperation(), target,
std::move(patterns), config)))
@@ -1510,6 +1513,21 @@ struct TestLegalizePatternDriver
clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"),
clEnumValN(ConversionMode::Partial, "partial",
"Perform a partial conversion"))};
+
+ Option<DialectConversionFoldingMode> foldingMode{
+ *this, "test-legalize-folding-mode",
+ llvm::cl::desc("The folding mode to use with the test driver"),
+ llvm::cl::init(DialectConversionFoldingMode::BeforePatterns),
+ llvm::cl::values(clEnumValN(DialectConversionFoldingMode::Never, "never",
+ "Never attempt to fold"),
+ clEnumValN(DialectConversionFoldingMode::BeforePatterns,
+ "before-patterns",
+ "Only attempt to fold not legal operations "
+ "before applying patterns"),
+ clEnumValN(DialectConversionFoldingMode::AfterPatterns,
+ "after-patterns",
+ "Only attempt to fold not legal operations "
+ "after applying patterns"))};
};
} // namespace
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/13890 Here is the relevant piece of the build log for the reference
|
Previously this only happened post checking if the op is legal, but was done unconditionally post (and before other legalization patterns). Add option to not attempt folding and one to do so as last resort.
Did consider but did not add a always attempt to fold option (which would have folded whether or not legal), but removed TODO about it.